{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 人脸检测和识别推理流程\n",
    "\n",
    "以下示例展示了如何使用facenet_pytorch python包，在使用在VGGFace2数据集上预训练的Inception Resnet V1模型上对图像数据集执行人脸检测和识别。\n",
    "\n",
    "以下PyTorch方法已包含：\n",
    "\n",
    "* 数据集\n",
    "* 数据加载器\n",
    "* GPU/CPU处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-20T11:23:56.432302Z",
     "start_time": "2023-07-20T11:23:54.333487Z"
    }
   },
   "outputs": [],
   "source": [
    "from facenet_pytorch import MTCNN, InceptionResnetV1\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "\n",
    "workers = 0 if os.name == 'nt' else 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 判断是否有nvidia GPU可用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-20T11:23:56.470217Z",
     "start_time": "2023-07-20T11:23:56.436290Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "在该设备上运行: cuda:0\n"
     ]
    }
   ],
   "source": [
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "print('在该设备上运行: {}'.format(device))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 定义MTCNN模块\n",
    "\n",
    "为了说明，默认参数已显示，但不是必需的。请注意，由于MTCNN是一组神经网络和其他代码，因此必须以以下方式传递设备，以便在需要内部复制对象时启用。\n",
    "\n",
    "查看`help(MTCNN)`获取更多细节。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-20T11:23:56.587926Z",
     "start_time": "2023-07-20T11:23:56.472212Z"
    }
   },
   "outputs": [],
   "source": [
    "mtcnn = MTCNN(\n",
    "    image_size=160, margin=0, min_face_size=20,\n",
    "    thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,\n",
    "    device=device\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 定义Inception Resnet V1模块\n",
    "\n",
    "设置classify=True以使用预训练分类器。对于本示例，我们将使用该模型输出嵌入/卷积特征。请注意，在推理过程中，将模型设置为`eval`模式非常重要。\n",
    "\n",
    "查看`help(InceptionResnetV1)`获取更多细节。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-20T11:23:56.988662Z",
     "start_time": "2023-07-20T11:23:56.588910Z"
    }
   },
   "outputs": [],
   "source": [
    "resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 定义数据集和数据加载器\n",
    "\n",
    "我们向数据集添加了`idx_to_class`属性，以便稍后轻松重编标签索引为身份名称。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-20T11:23:56.995647Z",
     "start_time": "2023-07-20T11:23:56.989657Z"
    }
   },
   "outputs": [],
   "source": [
    "def collate_fn(x):\n",
    "    return x[0]\n",
    "\n",
    "dataset = datasets.ImageFolder('../data/test_images')\n",
    "dataset.idx_to_class = {i:c for c, i in dataset.class_to_idx.items()}\n",
    "loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=workers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 执行MTCNN人脸检测\n",
    "\n",
    "迭代DataLoader对象并检测每个人脸及其关联的检测概率。如果检测到脸部，`MTCNN`的前向方法将返回裁剪到检测到的脸部的图像。默认情况下，仅返回检测到的单个面部-要使`MTCNN`返回所有检测到的面部，请在上面创建MTCNN对象时设置`keep_all=True`。\n",
    "\n",
    "要获取边界框而不是裁剪的人脸图像，可以调用较低级别的`mtcnn.detect()`函数。查看`help(mtcnn.detect)`获取详细信息。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-20T11:24:00.971832Z",
     "start_time": "2023-07-20T11:23:56.996640Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "检测到的人脸及其概率: 0.999983\n",
      "检测到的人脸及其概率: 0.999934\n",
      "检测到的人脸及其概率: 0.999733\n",
      "检测到的人脸及其概率: 0.999880\n",
      "检测到的人脸及其概率: 0.999992\n"
     ]
    }
   ],
   "source": [
    "aligned = []\n",
    "names = []\n",
    "for x, y in loader:\n",
    "    x_aligned, prob = mtcnn(x, return_prob=True)\n",
    "    if x_aligned is not None:\n",
    "        print('检测到的人脸及其概率: {:8f}'.format(prob))\n",
    "        aligned.append(x_aligned)\n",
    "        names.append(dataset.idx_to_class[y])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 计算图像嵌入\n",
    "\n",
    "MTCNN将返回所有面部图像的相同大小，从而可以使用Resnet识别模块轻松进行批处理。在这里，由于我们只有一些图像，因此我们构建一个单个批次并对其执行推理。\n",
    "\n",
    "对于实际数据集，代码应修改为控制传递给Resnet的批处理大小，特别是如果在GPU上处理。对于重复测试，最好将人脸检测（使用MTCNN）与嵌入或分类（使用InceptionResnetV1）分开，因为剪切面或边界框的计算可以一次执行，检测到的面部可以保存供将来使用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-20T11:24:01.037605Z",
     "start_time": "2023-07-20T11:24:00.973820Z"
    }
   },
   "outputs": [],
   "source": [
    "aligned = torch.stack(aligned).to(device)\n",
    "embeddings = resnet(aligned).detach().cpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 打印各类别的距离矩阵"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-20T11:24:01.050571Z",
     "start_time": "2023-07-20T11:24:01.038602Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>angelina_jolie</th>\n",
       "      <th>bradley_cooper</th>\n",
       "      <th>kate_siegel</th>\n",
       "      <th>paul_rudd</th>\n",
       "      <th>shea_whigham</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>angelina_jolie</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.447480</td>\n",
       "      <td>0.887728</td>\n",
       "      <td>1.429847</td>\n",
       "      <td>1.399073</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>bradley_cooper</th>\n",
       "      <td>1.447480</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.313749</td>\n",
       "      <td>1.013447</td>\n",
       "      <td>1.038684</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>kate_siegel</th>\n",
       "      <td>0.887728</td>\n",
       "      <td>1.313749</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.388377</td>\n",
       "      <td>1.379655</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>paul_rudd</th>\n",
       "      <td>1.429847</td>\n",
       "      <td>1.013447</td>\n",
       "      <td>1.388377</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.100502</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>shea_whigham</th>\n",
       "      <td>1.399073</td>\n",
       "      <td>1.038684</td>\n",
       "      <td>1.379655</td>\n",
       "      <td>1.100502</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                angelina_jolie  bradley_cooper  kate_siegel  paul_rudd  \\\n",
       "angelina_jolie        0.000000        1.447480     0.887728   1.429847   \n",
       "bradley_cooper        1.447480        0.000000     1.313749   1.013447   \n",
       "kate_siegel           0.887728        1.313749     0.000000   1.388377   \n",
       "paul_rudd             1.429847        1.013447     1.388377   0.000000   \n",
       "shea_whigham          1.399073        1.038684     1.379655   1.100502   \n",
       "\n",
       "                shea_whigham  \n",
       "angelina_jolie      1.399073  \n",
       "bradley_cooper      1.038684  \n",
       "kate_siegel         1.379655  \n",
       "paul_rudd           1.100502  \n",
       "shea_whigham        0.000000  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dists = [[(e1 - e2).norm().item() for e2 in embeddings] for e1 in embeddings]\n",
    "pd.DataFrame(dists, columns=names, index=names)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
